// component.cpp
#define _WIN32_DCOM
#include <windows.h>
#include <iostream.h>
#include <conio.h>
#include <stdlib.h>
#include <new>
#include "Component\component.h"
#include "registry.h"

// CLSID that CInsideCOM::GetUnmarshalClass reports for same-machine
// marshaling: the class that unmarshals the packet written by
// CInsideCOM::MarshalInterface on the client side.
CLSID CLSID_InsideCOMProxy = {0x10000004,0x0000,0x0000,{0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x01}};

// Number of live CInsideCOM objects (adjusted by its constructor/destructor).
long g_cComponents = 0;
// Net number of IClassFactory::LockServer(TRUE) calls.
// NOTE(review): neither counter is consulted in this file to decide when the
// server should shut down.
long g_cServerLocks = 0;
// Handles for the shared-memory channel to the custom proxy: created in
// CInsideCOM::MarshalInterface, used by TalkToProxy, closed in
// CInsideCOM::DisconnectObject. Zero-initialized; TalkToProxy waits until
// hStubEvent becomes non-zero.
HANDLE hFileMap, hStubEvent, hProxyEvent;

// Argument/result record for ISum::Sum exchanged through the shared-memory
// view in TalkToProxy. A request is read from just after the leading short
// method id; the reply record is written back starting at offset 0.
// NOTE(review): the proxy presumably uses an identical definition — keep the
// field order and types unchanged so both sides agree on the layout.
struct SumTransmit
{
	// First addend (request).
	int x;
	// Second addend (request).
	int y;
	// Result x + y (reply).
	int sum;
};

// The InsideCOM component: implements ISum and marshals itself with a custom
// IMarshal implementation (a named file mapping plus two named events) for
// same-machine clients. Cross-machine marshaling requests are delegated to
// COM's standard marshaler.
class CInsideCOM : public ISum, public IMarshal
{
public:
	// IUnknown
	ULONG __stdcall AddRef();
	ULONG __stdcall Release();
	HRESULT __stdcall QueryInterface(REFIID riid, void **ppv);

	// IMarshal
	HRESULT __stdcall GetUnmarshalClass(REFIID riid, void* pv, DWORD dwDestContext, void* pvDestContext, DWORD dwFlags, CLSID* pClsid);
	HRESULT __stdcall GetMarshalSizeMax(REFIID riid, void* pv, DWORD dwDestContext, void* pvDestContext, DWORD dwFlags, DWORD* pSize);
	HRESULT __stdcall MarshalInterface(IStream* pStream, REFIID riid, void* pv, DWORD dwDestContext, void* pvDestContext, DWORD dwFlags);
	HRESULT __stdcall DisconnectObject(DWORD dwReserved);
	HRESULT __stdcall UnmarshalInterface(IStream* pStream, REFIID riid, void** ppv);
	HRESULT __stdcall ReleaseMarshalData(IStream* pStream);

	// ISum
	HRESULT __stdcall Sum(int x, int y, int *sum);

	// Starts with one reference (owned by the creator) and counts the object
	// in g_cComponents.
	CInsideCOM() : m_cRef(1) { g_cComponents++; }
	// Logs the destruction and removes the object from g_cComponents.
	~CInsideCOM() { cout << "Component: CInsideCOM::~CInsideCOM()" << endl, g_cComponents--; }

private:
	// Reference count; Release deletes the object when it reaches zero.
	long m_cRef;
};

// Non-owning pointer to the CInsideCOM created by CFactory::CreateInstance.
// TalkToProxy dispatches proxy requests to it. Only one pointer is kept, so a
// newer instance replaces the previous one.
CInsideCOM *g_pInsideCOM;

// IMarshal::GetUnmarshalClass: report the class that unmarshals the data
// written by MarshalInterface.
//
// Same-machine contexts get the custom proxy, CLSID_InsideCOMProxy; the
// file-mapping/event channel only works locally. MSHCTX_DIFFERENTMACHINE is
// delegated to COM's standard marshaler.
//
// Under Windows NT/2000, MSHCTX_DIFFERENTMACHINE is reported in some cases even
// when the connection is local. When a client calls CoCreateInstance(), the SCM
// on the server performs the CreateInstance, and the SCM has historically had no
// way to tell the server whether the client is on a different machine, so it
// assumes the worst case. Work around this either by commenting out the
// MSHCTX_DIFFERENTMACHINE check below or by using the
// "CoGetClassObject client.cpp" client, which calls CoGetClassObject instead of
// CoCreateInstance so that the correct MSHCTX_ flag is used. Newer versions of
// Windows may have corrected this.
//
// Returns E_POINTER if pClsid is NULL, or the failure from CoGetStandardMarshal.
HRESULT CInsideCOM::GetUnmarshalClass(REFIID riid, void* pv, DWORD dwDestContext, void* pvDestContext, DWORD dwFlags, CLSID* pClsid)
{
	if(pClsid == NULL)
		return E_POINTER;

	if(dwDestContext == MSHCTX_DIFFERENTMACHINE)
	{
		// pMarshal is only valid if CoGetStandardMarshal succeeds; never
		// dereference it on failure.
		IMarshal* pMarshal = NULL;
		HRESULT hr = CoGetStandardMarshal(riid, static_cast<IUnknown*>(pv), dwDestContext, pvDestContext, dwFlags, &pMarshal);
		if(FAILED(hr))
			return hr;
		hr = pMarshal->GetUnmarshalClass(riid, pv, dwDestContext, pvDestContext, dwFlags, pClsid);
		pMarshal->Release();
		return hr;
	}
	*pClsid = CLSID_InsideCOMProxy;
	return S_OK;
}

// IMarshal::GetMarshalSizeMax: upper bound on the bytes MarshalInterface will
// write to the stream.
//
// Same-machine contexts: 255, the size of the buffer MarshalInterface fills
// with the NUL-terminated list of kernel-object names.
// MSHCTX_DIFFERENTMACHINE is delegated to COM's standard marshaler.
//
// Under Windows NT/2000, MSHCTX_DIFFERENTMACHINE is reported in some cases even
// when the connection is local. When a client calls CoCreateInstance(), the SCM
// on the server performs the CreateInstance, and the SCM has historically had no
// way to tell the server whether the client is on a different machine, so it
// assumes the worst case. Work around this either by commenting out the
// MSHCTX_DIFFERENTMACHINE check below or by using the
// "CoGetClassObject client.cpp" client, which calls CoGetClassObject instead of
// CoCreateInstance so that the correct MSHCTX_ flag is used. Newer versions of
// Windows may have corrected this.
//
// Returns E_POINTER if pSize is NULL, or the failure from CoGetStandardMarshal.
HRESULT CInsideCOM::GetMarshalSizeMax(REFIID riid, void* pv, DWORD dwDestContext, void* pvDestContext, DWORD dwFlags, DWORD* pSize)
{
	if(pSize == NULL)
		return E_POINTER;

	if(dwDestContext == MSHCTX_DIFFERENTMACHINE)
	{
		// pMarshal is only valid if CoGetStandardMarshal succeeds; never
		// dereference it on failure.
		IMarshal* pMarshal = NULL;
		HRESULT hr = CoGetStandardMarshal(riid, static_cast<IUnknown*>(pv), dwDestContext, pvDestContext, dwFlags, &pMarshal);
		if(FAILED(hr))
			return hr;
		hr = pMarshal->GetMarshalSizeMax(riid, pv, dwDestContext, pvDestContext, dwFlags, pSize);
		pMarshal->Release();
		return hr;
	}
	*pSize = 255;
	return S_OK;
}

// IMarshal::MarshalInterface: write the data the custom proxy needs in order
// to connect to this object.
//
// Same-machine contexts:
//   1. Create (or open, if they already exist) a 255-byte paging-file-backed
//      mapping named "FileMap" and two auto-reset, initially non-signaled
//      events named "StubEvent" and "ProxyEvent".
//   2. Write their names to pStream as the NUL-terminated string
//      "FileMap,StubEvent,ProxyEvent".
//   3. Publish the handles in hFileMap/hProxyEvent/hStubEvent for TalkToProxy.
//   4. Add one reference on behalf of the proxy connection.
// If any step fails, nothing is published and no reference is added.
// MSHCTX_DIFFERENTMACHINE is delegated to COM's standard marshaler.
//
// NOTE(review): only one connection is supported. A second call overwrites the
// globals without closing the earlier handles.
//
// Under Windows NT/2000, MSHCTX_DIFFERENTMACHINE is reported in some cases even
// when the connection is local. When a client calls CoCreateInstance(), the SCM
// on the server performs the CreateInstance, and the SCM has historically had no
// way to tell the server whether the client is on a different machine, so it
// assumes the worst case. Work around this either by commenting out the
// MSHCTX_DIFFERENTMACHINE check below or by using the
// "CoGetClassObject client.cpp" client, which calls CoGetClassObject instead of
// CoCreateInstance so that the correct MSHCTX_ flag is used. Newer versions of
// Windows may have corrected this.
HRESULT CInsideCOM::MarshalInterface(IStream* pStream, REFIID riid, void* pv, DWORD dwDestContext, void* pvDestContext, DWORD dwFlags)
{
	if(pStream == NULL)
		return E_POINTER;

	if(dwDestContext == MSHCTX_DIFFERENTMACHINE)
	{
		// pMarshal is only valid if CoGetStandardMarshal succeeds; never
		// dereference it on failure.
		IMarshal* pMarshal = NULL;
		HRESULT hr = CoGetStandardMarshal(riid, static_cast<IUnknown*>(pv), dwDestContext, pvDestContext, dwFlags, &pMarshal);
		if(FAILED(hr))
			return hr;
		hr = pMarshal->MarshalInterface(pStream, riid, pv, dwDestContext, pvDestContext, dwFlags);
		pMarshal->Release();
		return hr;
	}

	const char* szFileMapName = "FileMap";
	const char* szStubEventName = "StubEvent";
	const char* szProxyEventName = "ProxyEvent";

	// INVALID_HANDLE_VALUE requests a mapping backed by the paging file. The
	// former (HANDLE)0xFFFFFFFF literal is a different value on 64-bit Windows.
	HANDLE hMap = CreateFileMapping(INVALID_HANDLE_VALUE, NULL, PAGE_READWRITE, 0, 255, szFileMapName);
	if(hMap == NULL)
		return HRESULT_FROM_WIN32(GetLastError());

	HANDLE hStub = CreateEvent(NULL, FALSE, FALSE, szStubEventName);
	if(hStub == NULL)
	{
		HRESULT hrCreate = HRESULT_FROM_WIN32(GetLastError());
		CloseHandle(hMap);
		return hrCreate;
	}

	HANDLE hProxy = CreateEvent(NULL, FALSE, FALSE, szProxyEventName);
	if(hProxy == NULL)
	{
		HRESULT hrCreate = HRESULT_FROM_WIN32(GetLastError());
		CloseHandle(hStub);
		CloseHandle(hMap);
		return hrCreate;
	}

	// Marshal packet: comma-separated object names (29 bytes incl. NUL,
	// within the 255 reported by GetMarshalSizeMax).
	char buffer_to_write[255];
	strcpy(buffer_to_write, szFileMapName);
	strcat(buffer_to_write, ",");
	strcat(buffer_to_write, szStubEventName);
	strcat(buffer_to_write, ",");
	strcat(buffer_to_write, szProxyEventName);

	cout << "Component: CInsideCOM::MarshalInterface() = " << buffer_to_write << endl;
	ULONG num_written;
	HRESULT hr = pStream->Write(buffer_to_write, static_cast<ULONG>(strlen(buffer_to_write) + 1), &num_written);
	if(FAILED(hr))
	{
		// Without the packet no proxy can ever connect: drop the channel and
		// do not take the connection reference.
		CloseHandle(hProxy);
		CloseHandle(hStub);
		CloseHandle(hMap);
		return hr;
	}

	// Publish for TalkToProxy. hStubEvent goes last, through an interlocked
	// (full-barrier) store, because TalkToProxy polls it and then uses the
	// other handles.
	hFileMap = hMap;
	hProxyEvent = hProxy;
	InterlockedExchangePointer(&hStubEvent, hStub);

	// Reference held for the proxy connection; presumably balanced by the
	// Release that TalkToProxy performs when the proxy sends method id 1.
	AddRef();
	return hr;
}

// IMarshal::DisconnectObject: tear down the custom marshaling channel by
// closing the mapping and event handles created in MarshalInterface. This is
// reached through CoDisconnectObject (for example from TalkToProxy).
//
// Each global is cleared after it is closed, so a repeated disconnect cannot
// close a stale, possibly reused, handle value. Returns S_OK: the disconnect
// is implemented, so E_NOTIMPL would misreport the outcome.
HRESULT CInsideCOM::DisconnectObject(DWORD dwReserved)
{
	cout << "DisconnectObject" << endl;
	if(hFileMap != NULL)
	{
		CloseHandle(hFileMap);
		hFileMap = NULL;
	}
	if(hStubEvent != NULL)
	{
		CloseHandle(hStubEvent);
		hStubEvent = NULL;
	}
	if(hProxyEvent != NULL)
	{
		CloseHandle(hProxyEvent);
		hProxyEvent = NULL;
	}
	return S_OK;
}

// IMarshal::ReleaseMarshalData. This implementation neither reads pStream nor
// releases anything; it only traces the call and reports success.
HRESULT CInsideCOM::ReleaseMarshalData(IStream* pStream)
{
	const HRESULT hrDone = S_OK;
	cout << "ReleaseMarshalData" << endl;
	return hrDone;
}

// IMarshal::UnmarshalInterface. Unmarshaling is handled by the class named in
// GetUnmarshalClass (CLSID_InsideCOMProxy), not by this server object, so a
// call here is traced and rejected with E_UNEXPECTED.
HRESULT CInsideCOM::UnmarshalInterface(IStream* pStream, REFIID riid, void** ppv)
{
	const HRESULT hrReject = E_UNEXPECTED;
	cout << "UnmarshalInterface" << endl;
	return hrReject;
}

// IUnknown::AddRef. Returns the NEW reference count, as IUnknown::AddRef
// specifies (the previous code returned the old count).
// The increment is interlocked because the server runs in the multithreaded
// apartment (see main), where calls may arrive on several threads at once.
ULONG CInsideCOM::AddRef()
{
	LONG cRef = InterlockedIncrement(&m_cRef);
	cout << "Component: CInsideCOM::AddRef() m_cRef = " << cRef << endl;
	return static_cast<ULONG>(cRef);
}

// IUnknown::Release. Decrements the reference count and deletes the object
// when it reaches zero. Returns the new count.
// The decrement is interlocked because the server runs in the multithreaded
// apartment (see main). A non-atomic decrement could let two threads both
// observe a non-zero count (leak) or both observe zero (double delete).
ULONG CInsideCOM::Release()
{
	LONG cRef = InterlockedDecrement(&m_cRef);
	cout << "Component: CInsideCOM::Release() m_cRef = " << cRef << endl;
	if(cRef != 0)
		return static_cast<ULONG>(cRef);
	delete this;
	return 0;
}

// IUnknown::QueryInterface. Supports IUnknown, IMarshal and ISum and adds a
// reference on success.
// Returns E_POINTER if ppv is NULL. For any other interface, returns
// E_NOINTERFACE and sets *ppv to NULL.
HRESULT CInsideCOM::QueryInterface(REFIID riid, void **ppv)
{
	if(ppv == NULL)
		return E_POINTER;

	if(riid == IID_IUnknown)
	{
		cout << "Component: CInsideCOM::QueryInterface() for IUnknown" << endl;
		// Always hand out the ISum base as IUnknown, so every IUnknown
		// request yields the same pointer (COM identity).
		*ppv = (ISum*)this;
	}
	else if(riid == IID_IMarshal)
	{
		cout << "Component: CInsideCOM::QueryInterface for IMarshal" << endl;
		*ppv = (IMarshal*)this;
	}
	else if(riid == IID_ISum)
	{
		cout << "Component: CInsideCOM::QueryInterface() for ISum" << endl;
		*ppv = (ISum*)this;
	}
	else
	{
		*ppv = NULL;
		return E_NOINTERFACE;
	}
	AddRef();
	return S_OK;
}

// ISum::Sum: traces the addition and stores x + y in *sum. Always returns S_OK.
HRESULT CInsideCOM::Sum(int x, int y, int *sum)
{
	const int total = x + y;
	cout << "Component: CInsideCOM::Sum() " << x << " + " << y << " = " << total << endl;
	*sum = total;
	return S_OK;
}

// Class object for CLSID_InsideCOM. main() creates one instance and registers
// it with CoRegisterClassObject; CreateInstance produces CInsideCOM objects.
class CFactory : public IClassFactory
{
public:
	// IUnknown
	ULONG __stdcall AddRef();
	ULONG __stdcall Release();
	HRESULT __stdcall QueryInterface(REFIID riid, void** ppv);

	// IClassFactory
	HRESULT __stdcall CreateInstance(IUnknown *pUnknownOuter, REFIID riid, void** ppv);
	HRESULT __stdcall LockServer(BOOL bLock);

	// Starts with one reference, owned by the creator.
	CFactory() : m_cRef(1) { }
	// Logs the destruction.
	~CFactory() { cout << "Component: CFactory::~CFactory()" << endl; }

private:
	// Reference count; Release deletes the factory when it reaches zero.
	long m_cRef;
};

// IUnknown::AddRef. Returns the NEW reference count, as IUnknown::AddRef
// specifies (the previous code returned the old count).
// The increment is interlocked because the class object is registered
// REGCLS_MULTIPLEUSE in the multithreaded apartment (see main), so it may be
// called on several threads at once.
ULONG CFactory::AddRef()
{
	LONG cRef = InterlockedIncrement(&m_cRef);
	cout << "Component: CFactory::AddRef() m_cRef = " << cRef << endl;
	return static_cast<ULONG>(cRef);
}

// IUnknown::Release. Decrements the reference count and deletes the factory
// when it reaches zero. Returns the new count.
// The decrement is interlocked because the factory may be released from
// several threads in the multithreaded apartment (see main).
ULONG CFactory::Release()
{
	LONG cRef = InterlockedDecrement(&m_cRef);
	cout << "Component: CFactory::Release() m_cRef = " << cRef << endl;
	if(cRef != 0)
		return static_cast<ULONG>(cRef);
	delete this;
	return 0;
}

// IUnknown::QueryInterface for the class object. Supports IUnknown and
// IClassFactory and adds a reference on success.
// Returns E_POINTER if ppv is NULL. For any other interface, returns
// E_NOINTERFACE and sets *ppv to NULL.
HRESULT CFactory::QueryInterface(REFIID riid, void **ppv)
{
	if(ppv == NULL)
		return E_POINTER;

	if((riid == IID_IUnknown) || (riid == IID_IClassFactory))
	{
		cout << "Component: CFactory::QueryInterface() for IUnknown or IClassFactory " << this << endl;
		*ppv = (IClassFactory *)this;
	}
	else 
	{
		*ppv = NULL;
		return E_NOINTERFACE;
	}
	AddRef();
	return S_OK;
}

// IClassFactory::CreateInstance: create a CInsideCOM and return the requested
// interface in *ppv.
//
// Returns:
//   E_POINTER             if ppv is NULL.
//   CLASS_E_NOAGGREGATION if an outer unknown is supplied.
//   E_OUTOFMEMORY         if allocation fails.
//   Otherwise, the result of CInsideCOM::QueryInterface.
// *ppv is NULL on every failure. g_pInsideCOM is updated only when the caller
// actually receives the new object.
HRESULT CFactory::CreateInstance(IUnknown *pUnknownOuter, REFIID riid, void **ppv)
{
	if(ppv == NULL)
		return E_POINTER;
	*ppv = NULL;

	if(pUnknownOuter != NULL)
		return CLASS_E_NOAGGREGATION;

	// nothrow: a throwing new would make the NULL check dead code and let
	// std::bad_alloc escape across the COM boundary.
	CInsideCOM *pInsideCOM = new(std::nothrow) CInsideCOM;
	if(pInsideCOM == NULL)
		return E_OUTOFMEMORY;
	cout << "Component: CFactory::CreateInstance() " << pInsideCOM << endl;

	HRESULT hr = pInsideCOM->QueryInterface(riid, ppv);
	// Publish only an object that survives the Release below. If
	// QueryInterface failed, that Release deletes the object and
	// g_pInsideCOM would dangle.
	if(SUCCEEDED(hr))
		g_pInsideCOM = pInsideCOM;
	// Drop the construction reference; the caller now owns the one added by
	// QueryInterface.
	pInsideCOM->Release();
	return hr;
}

// IClassFactory::LockServer: increment (bLock != FALSE) or decrement the
// global server lock count. Always returns S_OK.
// Interlocked because LockServer may be called concurrently in the
// multithreaded apartment (see main).
HRESULT CFactory::LockServer(BOOL bLock)
{
	if(bLock)
		InterlockedIncrement(&g_cServerLocks);
	else
		InterlockedDecrement(&g_cServerLocks);
	return S_OK;
}

void RegisterComponent()
{
	ITypeLib* pTypeLib;
	HRESULT hr = LoadTypeLibEx(L"component.exe", REGKIND_DEFAULT, &pTypeLib);
	pTypeLib->Release();
	RegisterServer("component.exe", CLSID_InsideCOM, "Inside COM Sample", "Component.InsideCOM", "Component.InsideCOM.1", NULL);
}

// Handle the server's command-line switch (argv[1], leading '-' or '/'
// optional, case-insensitive).
//   (none)       register the component and exit with success
//   RegServer    register the component and exit with success
//   UnregServer  unregister the type library and server and exit with success
//   Embedding    return normally so main can serve COM clients
//   anything else: print "Invalid parameter" and exit with failure
// Exit codes follow the usual convention (0 = success). The previous
// exit(true)/exit(false) reported success as 1 and failure as 0.
// Registration is no longer performed unconditionally: previously RegServer
// registered twice, UnregServer registered just before unregistering, and
// every COM launch (-Embedding) re-registered.
void CommandLineParameters(int argc, char** argv)
{
	if(argc < 2)
	{
		RegisterComponent();
		cout << "No parameter, but registered anyway..." << endl;
		exit(EXIT_SUCCESS);
	}

	// strtok returns NULL when argv[1] is empty or consists only of '-'/'/'.
	char* szToken = strtok(argv[1], "-/"); 
	if(szToken == NULL)
	{
		cout << "Invalid parameter" << endl;
		exit(EXIT_FAILURE);
	}
	if(_stricmp(szToken, "RegServer") == 0)
	{
		RegisterComponent();
		cout << "RegServer" << endl;
		exit(EXIT_SUCCESS);
	}
	if(_stricmp(szToken, "UnregServer") == 0)
	{
		UnRegisterTypeLib(LIBID_Component, 1, 0, LANG_NEUTRAL, SYS_WIN32);
		UnregisterServer(CLSID_InsideCOM, "Component.InsideCOM", "Component.InsideCOM.1");
		cout << "UnregServer" << endl;
		exit(EXIT_SUCCESS);
	}
	if(_stricmp(szToken, "Embedding") != 0)
	{
		cout << "Invalid parameter" << endl;
		exit(EXIT_FAILURE);
	}
	// -Embedding: launched by COM; return and serve.
}

void TalkToProxy()
{
	while(hStubEvent == 0)
		Sleep(0);

	void* pMem = MapViewOfFile(hFileMap, FILE_MAP_WRITE, 0, 0, 0);
	short method_id = 0;

	while(true)
	{
		WaitForSingleObject(hStubEvent, INFINITE);
		memcpy(&method_id, pMem, sizeof(short));
		switch(method_id) // What method did the proxy call?
		{
		case 1:	// IUnknown::Release
			cout << "Component: Proxy request to call ISum::Release()" << endl;
			CoDisconnectObject(reinterpret_cast<IUnknown*>(g_pInsideCOM), 0);
			g_pInsideCOM->Release();
			return;
		case 2:	// ISum::Sum
			SumTransmit s;
			memcpy(&s, (short*)pMem+1, sizeof(SumTransmit));
			cout << "Component: Proxy request to call ISum::Sum()" << endl;
			g_pInsideCOM->Sum(s.x, s.y, &s.sum);
			memcpy(pMem, &s, sizeof(s));
			SetEvent(hProxyEvent);
		}
	}
}

void main(int argc, char** argv)
{
	CoInitializeEx(NULL, COINIT_MULTITHREADED);
	CommandLineParameters(argc, argv);

	DWORD dwRegister;
	IClassFactory *pIFactory = new CFactory();
	CoRegisterClassObject(CLSID_InsideCOM, pIFactory, CLSCTX_LOCAL_SERVER, REGCLS_MULTIPLEUSE, &dwRegister);
	TalkToProxy();
	CoRevokeClassObject(dwRegister);
	pIFactory->Release();
	CoUninitialize();
	_getch();
}